{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s8Qn4neXQY4o"
      },
      "source": [
        "# SpICE Classifier model (Keras Usage)\n",
        "Licensed under the Apache License, Version 2.0\n",
        "\n",
        "\n",
        "Paper: Speech Intelligibility Classifiers from Half-a-Million Utterances\n",
        "\n",
        "This colab walks through how to download the SpICE wav2vec2-based speech intelligibility classifier and run it on a sample audio file.\n",
        "\n",
        "You'll first load the wav2vec2 model from HuggingFace and then use the SpICE classifier head to generate predictions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3j37CyrWYuwk"
      },
      "outputs": [],
      "source": [
        "#@title Installation\n",
        "# Use the %pip magic (not !pip) so packages are installed into the environment\n",
        "# of the running kernel. The extras spec is quoted so the shell cannot treat\n",
        "# the square brackets as a glob pattern.\n",
        "%pip install \"transformers[tf-cpu]\"\n",
        "%pip install -U -q PyDrive"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VJ4PjhssKrlv"
      },
      "outputs": [],
      "source": [
        "#@title Imports\n",
        "\n",
        "import os\n",
        "import numpy as np\n",
        "import pickle\n",
        "import soundfile as sf\n",
        "import librosa\n",
        "import scipy.io.wavfile as wav\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "import transformers\n",
        "from transformers import Wav2Vec2Processor, Wav2Vec2CTCTokenizer, TFWav2Vec2Model\n",
        "\n",
        "import IPython\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from google.colab import drive\n",
        "\n",
        "SPEECH_URL = \"https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav\"  # noqa: E501\n",
        "SPEECH_FILE = \"_assets/speech.wav\"\n",
        "EXP_SAMPLE_RATE = 16000\n",
        "\n",
        "if not os.path.exists(SPEECH_FILE):\n",
        "    os.makedirs(\"_assets\", exist_ok=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x5rexoNRXJ-G"
      },
      "outputs": [],
      "source": [
        "#@title Download the sample speech file\n",
        "# IPython substitutes the Python constants SPEECH_FILE and SPEECH_URL into the\n",
        "# shell command, so the URL and the destination path are each defined only once.\n",
        "# -O writes straight to the destination, so no separate rename step is needed.\n",
        "!wget -O \"{SPEECH_FILE}\" \"{SPEECH_URL}\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XEcNVJXAmEe2"
      },
      "outputs": [],
      "source": [
        "#@title download wav2vec2 TF model\n",
        "# The tokenizer, processor and model are all loaded from the same checkpoint;\n",
        "# name it once so the three calls cannot drift apart.\n",
        "W2V2_CHECKPOINT = \"facebook/wav2vec2-base-960h\"\n",
        "\n",
        "# NOTE: the helper cells below only use `processor` and `hf_w2v2_model`;\n",
        "# `tokenizer` is not referenced by them.\n",
        "tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(W2V2_CHECKPOINT)\n",
        "processor = Wav2Vec2Processor.from_pretrained(W2V2_CHECKPOINT)\n",
        "hf_w2v2_model = TFWav2Vec2Model.from_pretrained(W2V2_CHECKPOINT)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zuWBcwb_Y0LC"
      },
      "outputs": [],
      "source": [
        "#@title Load Keras model from Hub\n",
        "# Load the SpICE classification model from TF Hub as a Keras layer.\n",
        "# get_prediction() later calls it on the wav2vec2 embeddings.\n",
        "spice_w2v2_cls_model = hub.KerasLayer('https://tfhub.dev/google/euphonia_spice/classification/1')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "XIo0JueKY-Yy"
      },
      "outputs": [],
      "source": [
        "#@title helpers for reading wav\n",
        "# Helpers for reading WAV sound files as np arrays of floats and running the models on them.\n",
        "def wavread(filename):\n",
        "  \"\"\"Read in audio data from a wav file.  Return d, sr.\"\"\"\n",
        "  # Read in wav file.\n",
        "  with open(filename, 'rb') as file_handle:\n",
        "    samplerate, wave_data = wav.read(file_handle)\n",
        "  # Normalize short ints to floats in range [-1..1).\n",
        "  data = np.asfarray(wave_data) / 32768.0\n",
        "  return data, samplerate\n",
        "\n",
        "def resample_aud(audio: np.ndarray,\n",
        "                 sample_rate: int,\n",
        "                 target_sr: int = 16000) -\u003e np.ndarray:\n",
        "  \"\"\"Resample audio from sample_rate to target_sr using librosa.\n",
        "\n",
        "  Args:\n",
        "    audio: Samples to resample.\n",
        "    sample_rate: Sample rate of `audio`, in Hz.\n",
        "    target_sr: Sample rate to convert to, in Hz.\n",
        "\n",
        "  Returns:\n",
        "    The samples resampled with librosa's 'kaiser_best' mode.\n",
        "  \"\"\"\n",
        "  resample_kwargs = dict(\n",
        "      orig_sr=sample_rate, target_sr=target_sr, res_type='kaiser_best')\n",
        "  return librosa.core.resample(audio, **resample_kwargs)\n",
        "\n",
        "def read_wav_resample(filename, target_sr=16000):\n",
        "  \"\"\"Read a WAV file as mono float32 audio at target_sr.\n",
        "\n",
        "  Args:\n",
        "    filename: Path of the WAV file to read.\n",
        "    target_sr: Sample rate, in Hz, the audio is converted to. Defaults to\n",
        "      16000, the rate used by the other helpers in this notebook.\n",
        "\n",
        "  Returns:\n",
        "    A 1-D float32 numpy array of samples at target_sr.\n",
        "  \"\"\"\n",
        "  audio, sample_rate = wavread(filename)\n",
        "  if audio.ndim \u003e 1:\n",
        "    # Multi-channel files are read as (samples, channels). Mix down to mono\n",
        "    # before resampling: librosa resamples along the last axis, and the\n",
        "    # embedding helper only accepts a single channel.\n",
        "    audio = audio.mean(axis=1)\n",
        "  # Resample only when the file is not already at the requested rate.\n",
        "  if sample_rate != target_sr:\n",
        "    audio = resample_aud(audio, sample_rate, target_sr=target_sr)\n",
        "  # The embedding helper rejects anything that is not float32.\n",
        "  return np.asarray(audio, dtype=np.float32)\n",
        "\n",
        "def samples_to_embedding_hfw2v2(model_input,\n",
        "                                hfw2v2,\n",
        "                                sample_rate=16000,\n",
        "                                name=None):\n",
        "  \"\"\"Map single-channel audio samples to wav2vec2 embeddings.\n",
        "\n",
        "  Args:\n",
        "    model_input: float32 audio samples, either a tf.Tensor or a value that\n",
        "      tf.convert_to_tensor accepts. Must be rank 1 after squeezing.\n",
        "    hfw2v2: A (processor, model) pair.\n",
        "    sample_rate: Sampling rate handed to the processor.\n",
        "    name: Optional label printed in the start-of-inference message.\n",
        "\n",
        "  Returns:\n",
        "    The model's last_hidden_state as a rank-3 numpy array.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If model_input is not float32.\n",
        "  \"\"\"\n",
        "  processor, model = hfw2v2\n",
        "\n",
        "  # NOTE: print() does not apply %-formatting, so the literal '%s' and the\n",
        "  # value of `name` are printed side by side.\n",
        "  print('[samples_to_embedding_hfw2v2] %s: Started inference.', name)\n",
        "\n",
        "  samples = model_input\n",
        "  if not tf.is_tensor(samples):\n",
        "    samples = tf.convert_to_tensor(samples)\n",
        "  if samples.dtype != tf.float32:\n",
        "    raise ValueError(f'hfw2v2 takes floats: {samples.dtype}')\n",
        "\n",
        "  # Normalization is done on the processor side, so checking that the input\n",
        "  # lies in [-1, 1] is perhaps not necessary.\n",
        "  if samples.shape.rank \u003e 1:\n",
        "    samples = tf.squeeze(samples)\n",
        "  samples.shape.assert_has_rank(1)\n",
        "\n",
        "  # Run the processor, then the wav2vec2 model, and hand back a numpy array.\n",
        "  features = processor(\n",
        "      samples.numpy(), sampling_rate=sample_rate, return_tensors='tf')\n",
        "  hidden_states = model(features.input_values).last_hidden_state\n",
        "  hidden_states.shape.assert_has_rank(3)\n",
        "  return hidden_states.numpy()\n",
        "\n",
        "def get_prediction(wav_file,\n",
        "                   hfw2v2=(processor, hf_w2v2_model),\n",
        "                   w2v2_model=spice_w2v2_cls_model):\n",
        "  \"\"\"Run the full pipeline on one WAV file and return its prediction.\n",
        "\n",
        "  The default arguments are bound when this cell is executed, so the cells\n",
        "  that load the wav2vec2 and SpICE models must have been run first.\n",
        "\n",
        "  Args:\n",
        "    wav_file: Path of the WAV file to classify.\n",
        "    hfw2v2: (processor, model) pair used to compute the embeddings.\n",
        "    w2v2_model: Callable applied to the embeddings to get predictions.\n",
        "\n",
        "  Returns:\n",
        "    The first element of the classifier output.\n",
        "  \"\"\"\n",
        "  samples = read_wav_resample(wav_file)\n",
        "  embedding = samples_to_embedding_hfw2v2(samples, hfw2v2=hfw2v2)\n",
        "  return w2v2_model(embedding)[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IGSgalAdZa03"
      },
      "outputs": [],
      "source": [
        "#@title Run on sample speech file\n",
        "# Expected output. The first line is the message printed by\n",
        "# samples_to_embedding_hfw2v2 (the literal %s is part of that message):\n",
        "#   [samples_to_embedding_hfw2v2] %s: Started inference. None\n",
        "# followed by the prediction, a length-5 float32 tensor:\n",
        "#   \u003ctf.Tensor: shape=(5,), dtype=float32, numpy=\n",
        "#   array([1.0000000e+00, 6.7461668e-21, 1.3921502e-21, 6.2975914e-24,\n",
        "#          0.0000000e+00], dtype=float32)\u003e\n",
        "get_prediction(SPEECH_FILE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "meJM5KUsMf95"
      },
      "source": [
        "## Other checks"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T90lWpjsMoF-"
      },
      "outputs": [],
      "source": [
        "# Sanity check: embed 16000 samples of constant 1.0 audio and summarize the result.\n",
        "# `out` is reused by the next cell.\n",
        "dummy_audio = np.ones([16000], dtype='float32')\n",
        "out = samples_to_embedding_hfw2v2(\n",
        "    dummy_audio, hfw2v2=(processor, hf_w2v2_model))\n",
        "print(out.shape)\n",
        "print(out.mean(), out.min(), out.max())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NVLaaytiNO5N"
      },
      "outputs": [],
      "source": [
        "# Run the SpICE classifier directly on the embedding computed in the previous cell.\n",
        "# Expected: \u003ctf.Tensor: shape=(1, 5), dtype=float32, numpy=\n",
        "# array([[9.9801159e-01, 3.9059367e-05, 1.9493260e-03, 3.9188830e-09,\n",
        "#        9.7403199e-20]], dtype=float32)\u003e\n",
        "spice_w2v2_cls_model(out)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1RILNyGUS0IVnICZSqs-IHpFhe45W8s5s",
          "timestamp": 1674514931350
        },
        {
          "file_id": "19OL5PBwrXiPx-qR3DkJAxhjTLV_JcDyT",
          "timestamp": 1674502777471
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
